r"""
@Doc      : Implementation of NeuralLog, PyTorch version
@Author   : Zedong Jia
@Edit Date: 2024/4/1
"""

from transformers import BertTokenizer, BertModel
from torch import nn
import torch
from drain3.file_persistence import FilePersistence
from drain3.template_miner import TemplateMiner
from drain3.template_miner_config import TemplateMinerConfig
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
)
from sklearn.model_selection import train_test_split
from typing import *
import os


class NeuralLog(nn.Module):
    """Transformer-encoder classifier over sequences of log-template embeddings.

    Input : float tensor [batch, seq_len, d_model] with seq_len <= max_len.
    Output: float tensor [batch, 2] of class scores (normal / anomaly).
    """

    def __init__(self, max_len, d_model, nhead, d_ff, dropout, device) -> None:
        super().__init__()
        # Fixed (non-learned) sinusoidal position encodings, precomputed once.
        self.position_tensor = self._create_position_tensor(max_len, d_model).to(device)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, d_ff, dropout, batch_first=True
        )
        self.globalAvgPooling1D = nn.AdaptiveAvgPool1d(1)
        self.dropout1 = nn.Dropout(dropout)
        self.fc1 = nn.Sequential(nn.Linear(d_model, 32), nn.ReLU())
        self.dropout2 = nn.Dropout(dropout)
        # NOTE(review): ReLU on the final logits clamps them to >= 0, which
        # limits what CrossEntropyLoss can express. Kept for checkpoint
        # compatibility, but consider a plain Linear head.
        self.fc2 = nn.Sequential(nn.Linear(32, 2), nn.ReLU())

    def _create_position_tensor(self, max_len, d_model):
        """Build the [1, max_len, d_model] sinusoidal position-encoding table."""
        angle_rads = np.arange(max_len).reshape(-1, 1) / np.power(
            10000, (2 * (np.arange(d_model).reshape(1, -1) // 2)) / np.float32(d_model)
        )
        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])  # even dims -> sin
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])  # odd dims  -> cos

        return torch.tensor(angle_rads, dtype=torch.float32).reshape(
            1, max_len, d_model
        )

    def _position_embedding(self, seq_len):
        # Slice the precomputed table to the actual sequence length.
        return self.position_tensor[:, :seq_len, :]

    def forward(self, inputs):
        seq_len = inputs.size(dim=1)
        # Fix: out-of-place add. The original `inputs += ...` mutated the
        # caller's tensor in place (and errors on a leaf tensor requiring grad).
        inputs = inputs + self._position_embedding(seq_len)
        inputs = self.encoder_layer(inputs)
        # Average over the sequence dimension -> [batch, d_model].
        hiddens = self.globalAvgPooling1D(inputs.transpose(1, 2)).squeeze(dim=2)
        hiddens = self.dropout1(hiddens)
        hiddens = self.fc1(hiddens)
        hiddens = self.dropout2(hiddens)
        outputs = self.fc2(hiddens)
        return outputs


class BertEncoder:
    """Embeds a sentence as a 768-dim vector by mean-pooling BERT hidden states.

    Results are memoized per input sentence in ``self.cache``.
    """

    def __init__(self, config) -> None:
        # config: {"tokenizer_path": ..., "model_path": ...}
        self._bert_tokenizer = BertTokenizer.from_pretrained(config["tokenizer_path"])
        self._bert_model = BertModel.from_pretrained(config["model_path"])
        self.cache = {}

    def __call__(self, sentence, no_wordpiece=False):
        r"""
        Return a list of 768 floats embedding `sentence`.

        :no_wordpiece: if True, drop words that are not in the BERT vocabulary
                       before encoding (avoids wordpiece splitting).
        """
        # Fix: cache consistently under the *original* sentence. The original
        # code checked the cache with the raw text but stored under the
        # vocabulary-filtered text, so no_wordpiece lookups never hit.
        key = sentence
        cached = self.cache.get(key)
        if cached is not None:
            return cached
        if no_wordpiece:
            vocab = self._bert_tokenizer.vocab
            sentence = " ".join(
                word for word in sentence.split(" ") if word in vocab
            )
        inputs = self._bert_tokenizer(
            sentence, truncation=True, return_tensors="pt", max_length=512
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = self._bert_model(**inputs)

        # Mean over tokens -> [1, 768]; take the single batch row.
        embedding = torch.mean(outputs.last_hidden_state, dim=1)[0].tolist()
        self.cache[key] = embedding
        return embedding


class DrainProcesser:
    r"""Incrementally mines log templates from raw log lines via drain3."""

    def __init__(self, config) -> None:
        r"""
        config: {
            "drain_save_path": "path/to",    # file persisting the miner state
            "drain_config_path": "path/to"   # drain3 .ini configuration
        }
        """
        self._drain_config_path = config["drain_config_path"]
        persistence = FilePersistence(config["drain_save_path"])
        miner_config = TemplateMinerConfig()
        miner_config.load(config["drain_config_path"])
        self._template_miner = TemplateMiner(persistence, config=miner_config)

    def __call__(self, sentence) -> str:
        # Feed the line to the miner (this updates its clustering state) and
        # return the template the line was matched/merged into.
        line = str(sentence).strip()
        result = self._template_miner.add_log_message(line)
        return result["template_mined"]


class LogDataset:
    """Accumulates (embedding-sequence, label) samples and builds DataLoaders.

    Stored shapes — X: [sample_num, max_len, 768], y: [sample_num].
    """

    # Width of one BERT-base embedding; zero padding must match it.
    _EMBED_DIM = 768

    def __init__(self, config) -> None:
        self._X = []
        self._y = []
        self._config = config
        self._drain = DrainProcesser(config["drain_config"])
        self._encoder = BertEncoder(config["bert_config"])

    def add_sample(self, logs, label):
        """Mine a template for each log line, embed it, pad/truncate to max_len.

        Returns the mined template for every input line (not truncated).
        """
        seq = []
        templates = []
        for log in logs:
            template = self._drain(log)
            seq.append(self._encoder(template))
            templates.append(template)
        # Right-pad short sequences with zero vectors so every sample has
        # exactly max_len steps.
        while len(seq) < self._config["max_len"]:
            seq.append([0] * self._EMBED_DIM)

        self._X.append(seq[: self._config["max_len"]])
        self._y.append(label)
        return templates

    def _as_tensors(self, X, y):
        """Convert nested Python lists to tensors via numpy — `torch.tensor`
        directly on nested lists is extremely slow for large datasets."""
        return (
            torch.tensor(np.asarray(X), dtype=torch.float32),
            torch.tensor(np.asarray(y), dtype=torch.long),
        )

    def get_loader(self, is_training=True):
        """Return (train_loader, test_loader).

        When is_training is False, train_loader is None and the whole dataset
        becomes the test loader.
        """
        batch_size = self._config["batch_size"]
        if is_training:
            X_train, X_test, y_train, y_test = train_test_split(
                self._X, self._y, test_size=self._config["test_size"], stratify=self._y
            )
            X_train, y_train = self._as_tensors(X_train, y_train)
            X_test, y_test = self._as_tensors(X_test, y_test)
            # Fix: reshuffle training batches every epoch — the original
            # loader iterated the samples in one fixed order.
            train_loader = DataLoader(
                list(zip(X_train, y_train)), batch_size, shuffle=True
            )
            test_loader = DataLoader(list(zip(X_test, y_test)), batch_size)
        else:
            X, y = self._as_tensors(self._X, self._y)
            train_loader = None
            test_loader = DataLoader(list(zip(X, y)), batch_size)
        return train_loader, test_loader


class Processer:
    """End-to-end NeuralLog pipeline: window raw logs, label windows from
    anomaly spans, then train / evaluate / test the classifier."""

    def __init__(self, config) -> None:
        # makedirs + exist_ok: tolerate a pre-existing result dir and create
        # missing parents (os.mkdir failed on both).
        os.makedirs(config["result_dir"], exist_ok=True)
        self.__config__ = config
        # dataset
        self.__win_size__ = config["win_size"]
        self.__dataset__ = LogDataset(config)
        self.__wins_labels__ = []

    def __create_model__(self):
        """Instantiate a NeuralLog model from the stored config."""
        return NeuralLog(
            self.__config__["max_len"],
            self.__config__["d_model"],
            self.__config__["nhead"],
            self.__config__["d_ff"],
            self.__config__["dropout"],
            self.__config__["device"],
        )

    def __desc_date__(self, ts):
        """Render a unix timestamp as a human-readable local-time string."""
        return datetime.fromtimestamp(ts).strftime("%Y年%m月%d日%H:%M:%S")

    def __metric__(
        self,
        y_true,
        y_pred,
        method: Literal["micro", "macro", "samples", "weighted"] = "weighted",
    ):
        """Print and return [precision, recall, f1] under `method` averaging.

        Consistency fix: zero_division=0 on all three metrics (the original
        set it only on precision), so degenerate predictions score 0 instead
        of emitting sklearn warnings.
        """
        precision = precision_score(y_true, y_pred, average=method, zero_division=0)
        recall = recall_score(y_true, y_pred, average=method, zero_division=0)
        f1score = f1_score(y_true, y_pred, average=method, zero_division=0)
        print(f"precision:{precision}")
        print(f"recall   :{recall}")
        print(f"f1score  :{f1score}")
        return [precision, recall, f1score]

    def __call__(self, log_df: pd.DataFrame, anomaly_df: pd.DataFrame):
        r"""Slice logs into fixed-size time windows and add them as samples.

        :log_df    : columns [timestamp, message]
        :anomaly_df: columns [st_time, ed_time], spans that are anomalous
        """
        log_columns = log_df.columns.tolist()
        assert "timestamp" in log_columns, "log_df requires `timestamp`"
        assert "message" in log_columns, "log_df requires `message`"
        log_df = log_df.sort_values(by="timestamp")
        log_df["label"] = 0
        anomaly_columns = anomaly_df.columns.tolist()
        assert "st_time" in anomaly_columns, "anomaly_df requires `st_time`"
        assert "ed_time" in anomaly_columns, "anomaly_df requires `ed_time`"

        # Mark every log line falling inside any anomaly span.
        for _, case in anomaly_df.iterrows():
            log_df.loc[
                (log_df["timestamp"] >= case["st_time"])
                & (log_df["timestamp"] <= case["ed_time"]),
                "label",
            ] = 1
        st_time = log_df.head(1)["timestamp"].item()
        ed_time = log_df.tail(1)["timestamp"].item()
        # Window bin edges; each window is identified by its start time.
        wins = range(
            int(st_time), int(ed_time) - self.__win_size__ + 1, self.__win_size__
        )
        wins_labels = wins[:-1]
        log_df["scope"] = pd.cut(
            log_df["timestamp"], wins, right=False, labels=wins_labels
        )
        process_bar = tqdm(
            total=len(wins_labels),
            desc=f"process {self.__desc_date__(st_time)} ~ {self.__desc_date__(ed_time)}",
        )
        for wins_label, group in log_df.groupby(by="scope", observed=False):
            # Fix: labels are 0/1, so the original `> 1` test never fired and
            # every window was marked normal. A window is anomalous when it
            # contains at least one anomalous line.
            label = 1 if (group["label"] > 0).any() else 0
            # (dropped dead `group["template"] = templates` — the groupby slice
            # was discarded, so the assignment only raised SettingWithCopy.)
            self.__dataset__.add_sample(group["message"].tolist(), label)
            self.__wins_labels__.append(wins_label)
            process_bar.update(1)
        process_bar.close()

    def train(self):
        """Train on the accumulated dataset, checkpointing on best eval loss."""
        model = self.__create_model__()
        train_loader, eval_loader = self.__dataset__.get_loader(is_training=True)
        optim = torch.optim.Adam(model.parameters(), lr=self.__config__["lr"])
        model.to(self.__config__["device"])
        entropy = nn.CrossEntropyLoss()
        best_loss = torch.inf
        for epoch in range(self.__config__["epochs"]):
            # Fix: re-enable train mode every epoch — __eval__ leaves the
            # model in eval mode, which silently disabled dropout for every
            # epoch after the first in the original code.
            model.train()
            avg_loss = 0
            for inputs, labels in tqdm(train_loader, desc=f"Training {epoch:3}"):
                inputs = inputs.to(device=self.__config__["device"])
                labels = labels.to(device=self.__config__["device"])
                optim.zero_grad()
                outputs = model(inputs)
                loss = entropy(outputs, labels)
                loss.backward()
                avg_loss += loss.item()
                optim.step()
            avg_loss /= len(train_loader)
            print(f"epoch: {epoch: 3}, training loss: {avg_loss:.6f}")
            best_loss = self.__eval__(model, eval_loader, best_loss)

    def __eval__(self, model, eval_loader, best_loss):
        """Evaluate, checkpoint on improvement, and return the best loss so far.

        Fix: the original returned the current epoch's loss unconditionally,
        so `best_loss` tracked the *previous* epoch instead of the minimum and
        later, worse checkpoints could overwrite a better one.
        """
        model.eval()
        with torch.no_grad():
            avg_loss = 0
            for inputs, labels in tqdm(eval_loader, desc="Evaling"):
                inputs = inputs.to(device=self.__config__["device"])
                labels = labels.to(device=self.__config__["device"])
                outputs = model(inputs)
                avg_loss += torch.nn.functional.cross_entropy(outputs, labels).item()
            avg_loss /= len(eval_loader)
            if avg_loss <= best_loss:
                print(f"reduce from {best_loss:.6f} -> {avg_loss:.6f}")
                torch.save(
                    model.state_dict(),
                    os.path.join(self.__config__["result_dir"], "model.pth"),
                )
        return min(avg_loss, best_loss)

    def test(self, model_path):
        """Load a checkpoint, predict over the whole dataset, and return
        per-window predictions plus macro-averaged metrics."""
        model = self.__create_model__()
        # map_location: allow loading GPU-trained weights onto the configured
        # (possibly CPU) device.
        model.load_state_dict(
            torch.load(model_path, map_location=self.__config__["device"])
        )
        _, test_loader = self.__dataset__.get_loader(is_training=False)
        model.to(self.__config__["device"])
        model.eval()
        all_labels = []
        all_preds = []
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            for inputs, labels in tqdm(test_loader, desc="Testing"):
                inputs = inputs.to(device=self.__config__["device"])
                outputs = model(inputs)
                all_labels.extend(labels.tolist())
                all_preds.extend(outputs.argmax(dim=1).tolist())

        return {
            "predict": list(
                zip(
                    self.__wins_labels__,
                    [
                        start_time + self.__win_size__
                        for start_time in self.__wins_labels__
                    ],
                    all_preds,
                    all_labels,
                )
            ),
            "metric": self.__metric__(all_labels, all_preds, "macro"),
        }
